#include<bits/stdc++.h>
#define int long long
using namespace std;
// All arithmetic is done modulo N.
const long long N = 998244353;

// Maps any value (including negative ones) into the range [0, N), so that a
// product of two reduced values stays below N*N, which fits in long long.
static long long reduce_mod(long long v)
{
    v %= N;                      // C++ '%' truncates toward zero: result may be negative
    return v < 0 ? v + N : v;
}

// Reads m, n, a, b from clown.in and writes
//     a^3 b^3 - 3 a^2 b^2 + 2ab   (mod N)
// to clown.out. With x = ab the polynomial factors as x(x-1)(x-2), which is
// what is evaluated below.
signed main()
{
    std::freopen("clown.in", "r", stdin);
    std::freopen("clown.out", "w", stdout);

    // m and n are read (in this order) to consume the input, but the formula
    // does not use them. NOTE(review): confirm the answer is meant to be
    // independent of m and n.
    long long m = 0;
    long long n = 0;
    long long a = 0;
    long long b = 0;
    std::cin >> m >> n >> a >> b;

    // Reduce the inputs before multiplying: multiplying raw values overflows
    // (signed overflow is UB) for large a or b, and negative inputs would
    // otherwise leave a negative remainder in the result.
    const long long x = reduce_mod(a) * reduce_mod(b) % N;

    // x(x-1)(x-2) == x^3 - 3x^2 + 2x. With 0 <= x < N the result is never
    // negative (x = 0 or 1 yields 0; otherwise every factor is non-negative),
    // and each intermediate product is below N*N.
    const long long ans = x * (x - 1) % N * (x - 2) % N;
    std::cout << ans;
    return 0;
}
